# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import argparse
import time
import warnings

import pytest
import torch
from tensordict import LazyStackedTensorDict, TensorDict
from torch import multiprocessing as mp
from torchrl._utils import logger as torchrl_logger


class TestShared:
    """Check that zeroing a shared sub-tensordict in a child process is
    visible in the parent through the shared storage."""

    @staticmethod
    def remote_process(command_pipe_child, command_pipe_parent, tensordict):
        """Child-process entry point: zero ``tensordict`` in place and signal
        completion over ``command_pipe_child``.

        Args:
            command_pipe_child: child end of the pipe; used to send "done".
            command_pipe_parent: parent end, inherited on process start;
                closed immediately so only the parent holds it.
            tensordict: shared-memory tensordict to zero in place.
        """
        # Close the inherited parent end so this process only holds the child end.
        command_pipe_parent.close()
        assert tensordict.is_shared()
        t0 = time.time()
        # In-place zeroing: writes go to the shared buffers the parent also sees.
        tensordict.zero_()
        torchrl_logger.info(f"zeroing time: {time.time() - t0}")
        command_pipe_child.send("done")
        command_pipe_child.close()
        del command_pipe_child, command_pipe_parent, tensordict

    @staticmethod
    def driver_func(subtd, td):
        """Spawn a child that zeroes ``subtd`` and verify that the writes are
        visible here, both via ``subtd`` and via ``td[0]``.

        Args:
            subtd: shared sub-tensordict handed to the child process.
            td: the full shared tensordict ``subtd`` was derived from.
        """
        assert subtd.is_shared()
        command_pipe_parent, command_pipe_child = mp.Pipe()
        proc = mp.Process(
            target=TestShared.remote_process,
            args=(command_pipe_child, command_pipe_parent, subtd),
        )
        proc.start()
        # Drop the parent's copy of the child end; the child owns it now.
        command_pipe_child.close()
        # Blocks until the child reports it has finished zeroing.
        command_pipe_parent.recv()
        for item in subtd.values():
            assert (item == 0).all()

        # The zeroing must also be visible through the original tensordict.
        for item in td[0].values():
            assert (item == 0).all()
        command_pipe_parent.close()
        proc.join()
        del command_pipe_child, command_pipe_parent, proc

    @pytest.mark.parametrize("indexing_method", range(3))
    def test_shared(self, indexing_method):
        """Run the driver with three ways of building a shared view of row 0."""
        torch.manual_seed(0)
        tensordict = TensorDict(
            source={
                "a": torch.randn(1000, 200),
                "b": torch.randn(1000, 100),
                "done": torch.zeros(1000, 100, dtype=torch.bool).bernoulli_(),
            },
            batch_size=[1000],
        )

        td = tensordict.clone().share_memory_()
        if indexing_method == 0:
            # Re-wrap views of row 0 in a fresh TensorDict, then share it.
            subtd = TensorDict(
                source={key: item[0] for key, item in td.items()},
                batch_size=[],
            ).share_memory_()
        elif indexing_method == 1:
            # Lazy sub-tensordict view (private API).
            subtd = td._get_sub_tensordict(0)
        elif indexing_method == 2:
            # Plain indexing.
            subtd = td[0]
        else:
            raise NotImplementedError

        assert subtd.is_shared()

        self.driver_func(subtd, td)


class TestStack:
    """Compare bulk vs. per-element in-place updates of a shared or
    memory-mapped tensordict performed from a child process."""

    @staticmethod
    def remote_process(command_pipe_child, command_pipe_parent, tensordict):
        """Child-process entry point: build a zeroed lazy stack shaped like
        ``tensordict``, then update ``tensordict`` in place either all at once
        ("stack") or one element at a time ("serial"), and report the elapsed
        time back over the pipe.
        """
        # Close the inherited parent end so this process only holds the child end.
        command_pipe_parent.close()
        assert isinstance(tensordict, TensorDict), f"td is of type {type(tensordict)}"
        assert tensordict.is_shared() or tensordict.is_memmap()
        # A lazy stack of zeroed, contiguous copies of each element.
        new_tensordict = LazyStackedTensorDict.lazy_stack(
            [
                tensordict[i].contiguous().clone().zero_()
                for i in range(tensordict.shape[0])
            ],
            0,
        )
        # Wait for the parent to choose the update strategy.
        cmd = command_pipe_child.recv()
        t0 = time.time()
        if cmd == "stack":
            # Single bulk copy of the whole stack.
            tensordict.copy_(new_tensordict)
        elif cmd == "serial":
            # One in-place update per element.
            for i, td in enumerate(new_tensordict.tensordicts):
                tensordict.update_at_(td, i)
        time_spent = time.time() - t0
        command_pipe_child.send(time_spent)
        command_pipe_child.close()
        del command_pipe_child, command_pipe_parent

    @staticmethod
    def driver_func(td, stack):
        """Spawn a child that zeroes ``td`` (bulk if ``stack`` else per-element),
        verify the writes are visible here, and return the child's elapsed time.
        """

        command_pipe_parent, command_pipe_child = mp.Pipe()
        proc = mp.Process(
            target=TestStack.remote_process,
            args=(command_pipe_child, command_pipe_parent, td),
        )
        proc.start()
        # Drop the parent's copy of the child end; the child owns it now.
        command_pipe_child.close()
        command_pipe_parent.send("stack" if stack else "serial")
        # Blocks until the child sends back the time it spent updating.
        time_spent = command_pipe_parent.recv()
        torchrl_logger.info(f"stack {stack}: time={time_spent}")
        # The child zeroed everything; the writes must be visible here.
        for item in td.values():
            assert (item == 0).all()
        proc.join()
        command_pipe_parent.close()
        return time_spent

    @pytest.mark.parametrize("shared", ["shared", "memmap"])
    def test_shared(self, shared):
        """Run both update strategies and warn if the bulk ("stack") update was
        slower than the per-element ("serial") one."""
        torchrl_logger.info(f"test_shared: shared={shared}")
        torch.manual_seed(0)
        tensordict = TensorDict(
            source={
                "a": torch.randn(100, 2),
                "b": torch.randn(100, 1),
                "done": torch.zeros(100, 1, dtype=torch.bool).bernoulli_(),
            },
            batch_size=[100],
        )
        if shared == "shared":
            tensordict.share_memory_()
        else:
            tensordict.memmap_()
        t_true = self.driver_func(tensordict, True)
        t_false = self.driver_func(tensordict, False)
        # Expectation: serial per-element updates should be slower than one
        # bulk stack copy; warn (don't fail) if timing says otherwise.
        if t_true > t_false:
            warnings.warn(
                "Updating each element of the tensordict did "
                "not take longer than updating the stack."
            )


@pytest.mark.parametrize("idx", [0, slice(200)])
@pytest.mark.parametrize("dtype", [torch.float, torch.bool])
def test_memmap(idx, dtype, large_scale=False):
    """Compare read/write latency of a shared-memory TensorDict against a
    memory-mapped one, and check that in-place sub-tensordict updates land.

    Args:
        idx: index used to carve out the sub-tensordict (int or slice).
        dtype: dtype of the stored tensors.
        large_scale: if True, use benchmark-sized tensors; default keeps the
            shapes small enough for CI.
    """
    N = 5000 if large_scale else 10
    H = 128 if large_scale else 8
    td = TensorDict(
        source={
            "a": torch.zeros(N, 3, H, H, dtype=dtype),
            "b": torch.zeros(N, 3, H, H, dtype=dtype),
            "c": torch.zeros(N, 3, H, H, dtype=dtype),
        },
        batch_size=[N],
    )

    td_sm = td.clone().share_memory_()
    td_memmap = td.clone().memmap_()

    torchrl_logger.info("\nTesting reading from TD")
    # Two passes: the first warms up caches/allocators, only the second is logged.
    for i in range(2):
        t0 = time.time()
        td_sm[idx].clone()
        if i == 1:
            torchrl_logger.info(f"sm: {time.time() - t0:4.4f} sec")

        t0 = time.time()
        td_memmap[idx].clone()
        if i == 1:
            torchrl_logger.info(f"memmap: {time.time() - t0:4.4f} sec")

    # Build a ones-filled tensordict with the sub-shape to write back.
    td_to_copy = td[idx].contiguous()
    for k in td_to_copy.keys():
        td_to_copy.set_(k, torch.ones_like(td_to_copy.get(k)))

    torchrl_logger.info("\nTesting writing to TD")
    for i in range(2):
        t0 = time.time()
        sub_td_sm = td_sm._get_sub_tensordict(idx)
        sub_td_sm.update_(td_to_copy)
        if i == 1:
            torchrl_logger.info(f"sm td: {time.time() - t0:4.4f} sec")
        torch.testing.assert_close(sub_td_sm.get("a"), td_to_copy.get("a"))

        t0 = time.time()
        # Was previously (mis)named `sub_td_sm` as well; this one is the
        # memmap sub-tensordict, not the shared-memory one.
        sub_td_memmap = td_memmap._get_sub_tensordict(idx)
        sub_td_memmap.update_(td_to_copy)
        if i == 1:
            torchrl_logger.info(f"memmap td: {time.time() - t0:4.4f} sec")
        # Compare against the underlying memmap tensor storage.
        torch.testing.assert_close(
            sub_td_memmap.get("a")._tensor, td_to_copy.get("a")
        )


if __name__ == "__main__":
    # Forward any unrecognized command-line flags straight to pytest.
    _, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
